import math
import torch
import droid_backends

import sophuspy as sp

from .chol import block_solve, schur_solve, schur_solve_mono_prior, solve_dR, block_solve_imu, schur_solve_imu
import geom.projective_ops as pops

from torch_scatter import scatter_sum


# utility functions for scattering ops
def safe_scatter_add_mat(A, ii, jj, n, m):
    """Scatter-add the entries of ``A`` (along dim 1) into a flattened n x m grid.

    Entry ``k`` is accumulated into slot ``ii[k] * m + jj[k]``. Entries whose
    row index lies outside ``[0, n)`` or whose column index lies outside
    ``[0, m)`` are silently dropped.

    Returns the ``scatter_sum`` result, whose dim 1 has size ``n * m``.
    """
    rows_ok = (ii >= 0) & (ii < n)
    cols_ok = (jj >= 0) & (jj < m)
    keep = rows_ok & cols_ok
    flat_slot = ii[keep] * m + jj[keep]
    return scatter_sum(A[:, keep], flat_slot, dim=1, dim_size=n * m)

def safe_scatter_add_vec(b, ii, n):
    """Scatter-add the entries of ``b`` (along dim 1) into ``n`` slots.

    Entry ``k`` is accumulated into slot ``ii[k]``; entries whose index lies
    outside ``[0, n)`` are silently dropped.

    Returns the ``scatter_sum`` result, whose dim 1 has size ``n``.
    """
    keep = (ii >= 0) & (ii < n)
    kept_slots = ii[keep]
    return scatter_sum(b[:, keep], kept_slots, dim=1, dim_size=n)

# apply retraction operator to inv-depth maps
def disp_retr(disps, dz, ii):
    """Add the update ``dz`` to ``disps`` at the frame indices ``ii``.

    ``ii`` is moved to ``dz``'s device, ``dz`` is scatter-summed along dim 1
    into a tensor with ``disps.shape[1]`` slots, and the result is added to
    ``disps``. A new tensor is returned; ``disps`` itself is not modified.
    """
    slots = ii.to(device=dz.device)
    increment = scatter_sum(dz, slots, dim=1, dim_size=disps.shape[1])
    return disps + increment

# apply retraction operator to poses
def pose_retr(poses, dx, ii):
    """Retract ``poses`` by the tangent update ``dx`` at the frame indices ``ii``.

    ``ii`` is moved to ``dx``'s device, ``dx`` is scatter-summed along dim 1
    into ``poses.shape[1]`` slots, and the result is handed to ``poses.retr``.
    """
    slots = ii.to(device=dx.device)
    tangent = scatter_sum(dx, slots, dim=1, dim_size=poses.shape[1])
    return poses.retr(tangent)

# apply retraction operator to velocities
def velo_retr(velos, dv, ii):
    """Add the update ``dv`` to ``velos`` at the frame indices ``ii``.

    ``ii`` is moved to ``dv``'s device, ``dv`` is scatter-summed along dim 1
    into ``velos.shape[1]`` slots, and the result is added to ``velos``.
    A new tensor is returned.
    """
    slots = ii.to(device=dv.device)
    increment = scatter_sum(dv, slots, dim=1, dim_size=velos.shape[1])
    return velos + increment

# apply retraction operator to biases
def bias_retr(biass, dv, ii):
    """Add the update ``dv`` to ``biass`` at the frame indices ``ii``.

    ``ii`` is moved to ``dv``'s device, ``dv`` is scatter-summed along dim 1
    into ``biass.shape[1]`` slots, and the result is added to ``biass``.
    A new tensor is returned.
    """
    slots = ii.to(device=dv.device)
    increment = scatter_sum(dv, slots, dim=1, dim_size=biass.shape[1])
    return biass + increment

def BA_prepare(target, weight, eta, poses, disps, intrinsics, ii, jj, T_ci_c0=None,
               H=None, v=None, fixedp=1, D=6, t0=None, t1=None):
    """ Construct linear system for Full Bundle Adjustment

    Builds the reprojection normal equations for the edges ``(ii, jj)``
    without solving them.

    Args:
        target, weight: reprojection targets and per-residual weights; both
            are reshaped to ``(B, N, -1, 1)`` with ``N = ii.shape[0]``.
        eta: damping added to the depth diagonal ``C`` (viewed to C's shape).
        poses, disps, intrinsics: forwarded to
            ``pops.projective_transform_imu``; ``disps`` is ``(B, P, ht, wd)``.
        ii, jj: frame indices of the edge endpoints.
        T_ci_c0: forwarded to the projection as ``Tcb``.
        H, v: optional existing pose Hessian / gradient. When ``H`` is given,
            ``H`` and ``v`` are updated IN PLACE and are not returned.
        fixedp: number of leading frames excluded from the pose system
            (used when ``t0`` is None).
        D: per-frame state size. When it is not 6 the pose Jacobians are
            zero-padded to ``D`` columns (assumes the projection returns
            6 pose columns -- TODO confirm).
        t0: if given, replaces ``fixedp``.
        t1: accepted but not used by this function.

    Returns:
        ``(H, E, C, v, w, chi2, chi2R)`` when ``H`` is None, otherwise
        ``(E, C, w, chi2, chi2R)``. ``H`` keeps the flattened block layout
        produced by ``safe_scatter_add_mat`` (dim 1 of size ``P*P``); ``E`` is
        viewed as ``(B, P, M, D, ht*wd)`` with ``P`` the number of free frames
        and ``M`` the number of unique entries of ``ii``. ``chi2`` and
        ``chi2R`` hold the same weighted squared-residual sum.

    Raises:
        TypeError: if ``H`` is given without ``v``.
    """
    # Fail before touching H: the accumulation below is in place, so a late
    # failure on `v` would leave the caller's H half-updated.
    if H is not None and v is None:
        raise TypeError("BA_prepare: `v` must be provided together with `H`")

    B, P, ht, wd = disps.shape
    N = ii.shape[0]
    kx, kk = torch.unique(ii, return_inverse=True)
    M = kx.shape[0]

    ### 1: compute jacobians and residuals ###
    coords, valid, (Ji, Jj, Jz) = pops.projective_transform_imu(
        poses, disps, intrinsics, ii, jj, jacobian=True, Tcb=T_ci_c0)

    r = (target - coords).view(B, N, -1, 1)
    rw = .001 * (valid * weight).view(B, N, -1, 1)

    # weighted squared residual, evaluated once and reported under both names
    chi2 = torch.sum((rw * r).transpose(2,3) @ r)
    chi2R = chi2.clone()

    ### 2: construct linear system ###
    if D != 6:
        # pad the pose Jacobians with zero columns for the extra state entries
        Jnull = torch.cat([torch.zeros_like(Ji), torch.zeros_like(Ji)], dim=-1)[...,:D-6]
        Ji = torch.cat([Ji, Jnull], dim=-1).reshape(B, N, -1, D)
        Jj = torch.cat([Jj, Jnull], dim=-1).reshape(B, N, -1, D)
    else:
        Ji = Ji.reshape(B, N, -1, D)
        Jj = Jj.reshape(B, N, -1, D)
    wJiT = (rw * Ji).transpose(2,3)
    wJjT = (rw * Jj).transpose(2,3)

    Jz = Jz.reshape(B, N, ht*wd, -1)

    Hii = torch.matmul(wJiT, Ji)
    Hij = torch.matmul(wJiT, Jj)
    Hji = torch.matmul(wJjT, Ji)
    Hjj = torch.matmul(wJjT, Jj)

    vi = torch.matmul(wJiT, r).squeeze(-1)
    vj = torch.matmul(wJjT, r).squeeze(-1)

    # pose/depth coupling blocks
    Ei = (wJiT.view(B,N,D,ht*wd,-1) * Jz[:,:,None]).sum(dim=-1)
    Ej = (wJjT.view(B,N,D,ht*wd,-1) * Jz[:,:,None]).sum(dim=-1)

    rw = rw.view(B, N, ht*wd, -1)
    r = r.view(B, N, ht*wd, -1)
    wk = torch.sum(rw*r*Jz, dim=-1)
    Ck = torch.sum(rw*Jz*Jz, dim=-1)

    # Frames before `fixedp` are held fixed: shifting the indices makes them
    # negative so that the safe scatter helpers drop their contributions.
    if t0 is not None:
        fixedp = t0
    P = P - fixedp
    ii = ii - fixedp
    jj = jj - fixedp

    E = safe_scatter_add_mat(Ei, ii, kk, P, M) + \
        safe_scatter_add_mat(Ej, jj, kk, P, M)
    C = safe_scatter_add_vec(Ck, kk, M)
    w = safe_scatter_add_vec(wk, kk, M)
    C += eta.view(*C.shape) + 1e-7
    E = E.view(B, P, M, D, ht*wd)

    H_vis = safe_scatter_add_mat(Hii, ii, ii, P, P) + \
            safe_scatter_add_mat(Hij, ii, jj, P, P) + \
            safe_scatter_add_mat(Hji, jj, ii, P, P) + \
            safe_scatter_add_mat(Hjj, jj, jj, P, P)
    v_vis = safe_scatter_add_vec(vi, ii, P) + \
            safe_scatter_add_vec(vj, jj, P)

    if H is None:
        return H_vis, E, C, v_vis, w, chi2, chi2R

    # accumulate into the caller's system in place
    H += H_vis
    v += v_vis
    return E, C, w, chi2, chi2R


def BA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1):
    """ Full Bundle Adjustment

    Performs one solve-and-update step over poses and inverse depths for the
    edges ``(ii, jj)``. The first ``fixedp`` frames are excluded from the pose
    system. After the update, inverse depths above 10 are reset to zero and
    the result is clamped to a minimum of 0.001.

    Returns ``(poses, disps, dzcov)`` where ``dzcov`` is the third value
    returned by ``schur_solve``.
    """
    batch, num_frames, ht, wd = disps.shape
    num_edges = ii.shape[0]
    pose_dim = poses.manifold_dim
    num_pix = ht * wd

    ### 1: compute jacobians and residuals ###
    coords, valid, (Ji, Jj, Jz) = pops.projective_transform(
        poses, disps, intrinsics, ii, jj, jacobian=True)

    resid = (target - coords).view(batch, num_edges, -1, 1)
    wgt = .001 * (valid * weight).view(batch, num_edges, -1, 1)

    ### 2: construct linear system ###
    Ji = Ji.reshape(batch, num_edges, -1, pose_dim)
    Jj = Jj.reshape(batch, num_edges, -1, pose_dim)
    Jz = Jz.reshape(batch, num_edges, num_pix, -1)

    wJi_t = (wgt * Ji).transpose(2, 3)
    wJj_t = (wgt * Jj).transpose(2, 3)

    # pose-pose Hessian blocks and pose gradients, one per edge
    Hii = wJi_t @ Ji
    Hij = wJi_t @ Jj
    Hji = wJj_t @ Ji
    Hjj = wJj_t @ Jj
    vi = (wJi_t @ resid).squeeze(-1)
    vj = (wJj_t @ resid).squeeze(-1)

    # pose-depth coupling blocks
    Ei = (wJi_t.view(batch, num_edges, pose_dim, num_pix, -1) * Jz[:, :, None]).sum(dim=-1)
    Ej = (wJj_t.view(batch, num_edges, pose_dim, num_pix, -1) * Jz[:, :, None]).sum(dim=-1)

    # per-pixel depth gradient and depth diagonal
    wgt_px = wgt.view(batch, num_edges, num_pix, -1)
    resid_px = resid.view(batch, num_edges, num_pix, -1)
    wk = torch.sum(wgt_px * resid_px * Jz, dim=-1)
    Ck = torch.sum(wgt_px * Jz * Jz, dim=-1)

    # depth maps are indexed by the unique source frames (before shifting)
    kx, kk = torch.unique(ii, return_inverse=True)
    num_depth = kx.shape[0]

    # shift indices so the fixed leading frames become negative and are
    # dropped by the safe scatter helpers
    num_free = num_frames - fixedp
    pi = ii - fixedp
    pj = jj - fixedp

    H = safe_scatter_add_mat(Hii, pi, pi, num_free, num_free) + \
        safe_scatter_add_mat(Hij, pi, pj, num_free, num_free) + \
        safe_scatter_add_mat(Hji, pj, pi, num_free, num_free) + \
        safe_scatter_add_mat(Hjj, pj, pj, num_free, num_free)

    E = safe_scatter_add_mat(Ei, pi, kk, num_free, num_depth) + \
        safe_scatter_add_mat(Ej, pj, kk, num_free, num_depth)

    v = safe_scatter_add_vec(vi, pi, num_free) + \
        safe_scatter_add_vec(vj, pj, num_free)

    C = safe_scatter_add_vec(Ck, kk, num_depth)
    w_depth = safe_scatter_add_vec(wk, kk, num_depth)

    C = C + eta.view(*C.shape) + 1e-7

    H = H.view(batch, num_free, num_free, pose_dim, pose_dim)
    E = E.view(batch, num_free, num_depth, pose_dim, num_pix)

    ### 3: solve the system ###
    # NOTE(review): three values are unpacked here while InertialFullBA unpacks
    # two from the same helper -- confirm against schur_solve's signature.
    dx, dz, dzcov = schur_solve(H, E, C, v, w_depth)

    ### 4: apply retraction ###
    free_frames = torch.arange(num_free) + fixedp
    poses = pose_retr(poses, dx, free_frames)
    disps = disp_retr(disps, dz.view(batch, -1, ht, wd), kx)

    too_large = disps > 10
    disps = torch.where(too_large, torch.zeros_like(disps), disps)
    disps = disps.clamp(min=0.001)

    return poses, disps, dzcov

# not used anywhere
def MoBA(target, weight, eta, poses, disps, intrinsics, ii, jj, fixedp=1, rig=1):
    """ Motion only bundle adjustment

    Builds the pose-only normal equations from the reprojection residuals of
    the edges ``(ii, jj)``, solves them with ``block_solve`` and retracts the
    poses. Frame indices and the frame count are integer-divided by ``rig``
    before the first ``fixedp`` entries are excluded. ``eta`` is accepted but
    not used.

    Returns the updated poses.
    """
    batch, num_frames, ht, wd = disps.shape
    num_edges = ii.shape[0]
    pose_dim = poses.manifold_dim

    ### 1: compute jacobians and residuals ###
    coords, valid, (Ji, Jj, _) = pops.projective_transform(
        poses, disps, intrinsics, ii, jj, jacobian=True)

    resid = (target - coords).view(batch, num_edges, -1, 1)
    wgt = .001 * (valid * weight).view(batch, num_edges, -1, 1)

    ### 2: construct linear system ###
    Ji = Ji.reshape(batch, num_edges, -1, pose_dim)
    Jj = Jj.reshape(batch, num_edges, -1, pose_dim)
    wJi_t = (wgt * Ji).transpose(2, 3)
    wJj_t = (wgt * Jj).transpose(2, 3)

    Hii = wJi_t @ Ji
    Hij = wJi_t @ Jj
    Hji = wJj_t @ Ji
    Hjj = wJj_t @ Jj

    vi = (wJi_t @ resid).squeeze(-1)
    vj = (wJj_t @ resid).squeeze(-1)

    # collapse rig cameras onto one pose index, then drop the fixed frames
    num_free = num_frames // rig - fixedp
    pi = ii // rig - fixedp
    pj = jj // rig - fixedp

    H = safe_scatter_add_mat(Hii, pi, pi, num_free, num_free) + \
        safe_scatter_add_mat(Hij, pi, pj, num_free, num_free) + \
        safe_scatter_add_mat(Hji, pj, pi, num_free, num_free) + \
        safe_scatter_add_mat(Hjj, pj, pj, num_free, num_free)

    v = safe_scatter_add_vec(vi, pi, num_free) + \
        safe_scatter_add_vec(vj, pj, num_free)

    H = H.view(batch, num_free, num_free, pose_dim, pose_dim)

    ### 3: solve the system ###
    dx = block_solve(H, v)

    ### 4: apply retraction ###
    free_frames = torch.arange(num_free) + fixedp
    return pose_retr(poses, dx, free_frames)


def get_prior_depth_aligned(depth_prior, scales):
    """Scale a depth prior by a bilinearly interpolated scale grid.

    A regular ``ht x wd`` sampling grid spanning the last two dims of
    ``scales`` is built (stopping 1e-6 short of the last cell so the sample
    stays strictly inside the grid) and passed, once per prior map, to
    ``droid_backends.bi_inter``.

    Args:
        depth_prior: tensor of shape ``(M, ht, wd)``.
        scales: scale grid whose last two dims are ``(hs, ws)``.

    Returns:
        ``(depth_prior * interpolated_scales, Jbi)`` where ``Jbi`` is the
        second output of ``bi_inter`` (presumably the Jacobian of the
        interpolation w.r.t. the grid values -- verify against the backend).
    """
    M, ht, wd = depth_prior.shape
    hs, ws = scales.shape[-2:]

    row_pos = torch.linspace(0, hs-1-1e-6, ht)
    col_pos = torch.linspace(0, ws-1-1e-6, wd)
    mesh_r, mesh_c = torch.meshgrid(row_pos, col_pos, indexing='ij')

    # Last dim is ordered (position along ws, position along hs). The grid is
    # placed on the device of `scales` rather than the current CUDA device so
    # both inputs of bi_inter always live on the same GPU.
    grid = torch.stack((mesh_c, mesh_r), -1).to(scales.device)
    grid = grid.unsqueeze(0).expand(M, -1, -1, -1).contiguous()

    mscales_bi, Jbi = droid_backends.bi_inter(scales, grid)
    depth_prior_aligned = depth_prior * mscales_bi
    return depth_prior_aligned, Jbi


def JDSA(target, weight, eta, poses, disps, intrinsics, disps_prior, dscales, ii, jj, alpha):
    """Jointly update inverse depths and the scale grids of a depth prior.

    Combines the reprojection terms returned by ``droid_backends.proj_trans``
    with a prior term that ties ``disps`` to ``disps_prior`` scaled by the
    interpolated ``dscales``, solves with ``schur_solve_mono_prior`` and
    applies the updates.

    Side effects: ``dscales`` is updated IN PLACE at the unique frames of
    ``ii``.

    Returns ``(disps, dscales, dzcov)``. After the update, inverse depths
    above 10 are reset to zero and the result is clamped to a minimum of 0.001.
    """

    B, P, ht, wd = disps.shape
    N = ii.shape[0]

    ### 1: compute jacobians and residuals ###
    # reprojection part of the depth system (diagonal C and gradient w)
    C, w = droid_backends.proj_trans(poses.data.squeeze(), disps[0], intrinsics[0], target, weight, ii, jj)

    # unique source frames: these are the depth maps being optimized
    kx, kk = torch.unique(ii, return_inverse=True)
    M = kx.shape[0]

    disps_prior = disps_prior[kx]
    # m == 1 where a prior value is available (prior > 0), else 0
    m = (disps_prior > 0).to(torch.float).view(-1, ht*wd)

    hs, ws = dscales.shape[-2:]
    disps_bi, Jbi = get_prior_depth_aligned(disps_prior, dscales[kx])

    # prior residual: current inverse depth minus the scale-aligned prior
    rd = (disps[0,kx] - disps_bi).view(-1, ht*wd)
    # Jacobian of the prior residual w.r.t. the inverse depth (all ones)
    Jd = torch.ones_like(rd).view(1, -1, 1, ht*wd)
    # Jd = (-1. / (disps[0,kx] ** 2)).view(1, -1, 1, ht*wd)
    # Jacobian of the prior residual w.r.t. the scale grid, masked by m
    Jso = -m.unsqueeze(-1) * disps_prior.view(-1, ht*wd).unsqueeze(-1) * Jbi.view(M, ht*wd, -1)[None]

    # broadcast the scalar prior weight to one value per pixel
    alpha = torch.ones(M,ht*wd,1).float().cuda() * alpha

    D = hs*ws
    # temporarily re-base the frame ids so the first unique frame is slot 0
    # NOTE(review): slots >= M are dropped by the safe scatter helpers, so this
    # appears to assume kx is a contiguous range -- confirm with callers.
    fixedp = kx[0]
    kx = kx - fixedp
    wJsoT = (alpha * Jso).transpose(2,3)
    Hs = safe_scatter_add_mat(wJsoT @ Jso, kx, kx, M, M).view(B, M, M, D, D)
    Es = safe_scatter_add_mat(wJsoT * Jd, kx, kx, M, M).view(B, M, M, D, ht*wd)
    vs = safe_scatter_add_vec(-wJsoT @ rd[None].unsqueeze(-1), kx, M)
    # restore absolute frame ids for the retraction below
    kx += fixedp

    alpha = alpha.squeeze()
    # where a prior exists add its weight to the depth diagonal, elsewhere use eta
    C = C[None] + m * alpha * (Jd * Jd).squeeze() + (1-m) * eta.view(*C.shape)
    w = w[None] - m * alpha * rd * Jd.squeeze()

    ### 3: solve the system ###
    dso, dz, dzcov = schur_solve_mono_prior(C, w, Hs, Es, vs, dzcov=True)

    ### 4: apply retraction ###
    disps = disp_retr(disps, dz.view(B,-1,ht,wd), kx)
    # in-place update of the caller's scale grids
    dscales[kx] += dso.view(-1, hs, ws)

    disps = torch.where(disps > 10, torch.zeros_like(disps), disps)
    disps = disps.clamp(min=0.001)

    return disps, dscales, dzcov


#################
# Begin IMU Related Part

import numpy as np
def get_preint_factors(poses_bw, velos_w, biass_w, preints, Rwg, ii, jj, GDir=False, wo_pose=False, scale=1, preint_scale=1e-5):
    """Build the normal-equation blocks of the IMU preintegration factors.

    For every edge ``(i, j)`` in ``zip(ii, jj)`` the preintegration object
    ``preints[(i, j)]`` is evaluated at the current state of batch element 0.

    Args:
        poses_bw: poses; ``poses_bw[0, k].inv().matrix()`` is passed to the
            preintegration object as the pose of frame ``k``.
        velos_w, biass_w: per-frame velocities and biases (batch element 0
            is read).
        preints: mapping ``(i, j) -> preintegration object`` providing
            ``set_new_bias``, ``compute_error``, ``jacobian``, ``info`` and,
            when ``GDir`` is set, ``jacobian_GDir``.
        Rwg: forwarded unchanged to the preintegration object.
        ii, jj: iterables of 0-dim index tensors (``.item()`` is called).
        GDir: also return the blocks for the extra columns formed by
            concatenating ``jacobian_GDir`` and the last output of ``jacobian``.
        wo_pose: drop the pose columns from the per-frame Jacobians.
        scale: forwarded to ``compute_error``.
        preint_scale: global weight applied to the information matrices.

    Returns:
        ``(Hii, Hij, Hji, Hjj, vi, vj, chi2)``, or with ``GDir``
        ``(Hii, Hij, Hji, Hjj, vi, vj, Hgs, vgs, Hi_gs, Hj_gs, chi2)``.
        All tensors are float tensors on ``'cuda'`` with a leading batch
        dimension of 1.

    Side effects:
        ``set_new_bias`` is called on every visited preintegration object.
    """
    def _as_batched_tensor(per_edge):
        # stack the per-edge numpy arrays and prepend the batch dimension
        return torch.tensor(np.stack(per_edge, axis=0)[None], dtype=torch.float, device='cuda')

    Jinti, Jintj, Jgs = [], [], []
    eint, info = [], []
    for i, j in zip(ii, jj):
        id0, id1 = i.item(), j.item()
        inter = preints[(id0,id1)]
        T0_wb = poses_bw[0,id0].inv().matrix().cpu().numpy()
        T1_wb = poses_bw[0,id1].inv().matrix().cpu().numpy()
        V0w = velos_w[0,id0].cpu().numpy()
        V1w = velos_w[0,id1].cpu().numpy()

        # re-linearize the preintegration at the current bias of frame i
        inter.set_new_bias(biass_w[0,id0].cpu().numpy())
        eint.append(inter.compute_error(T0_wb, T1_wb, V0w, V1w, Rwg, scale=scale))
        info.append(inter.info)

        # NOTE(review): `scale` is passed to compute_error but not to
        # jacobian -- confirm this is intended.
        Ji, Jj, Jvi, Jvj, Jbg, Jba, Js = inter.jacobian(T0_wb, T1_wb, V0w, V1w, Rwg)
        # the bias columns of frame j are zero: only frame i's bias is read
        if wo_pose:
            Jinti.append(np.concatenate([Jvi, Jbg, Jba], axis=-1))
            Jintj.append(np.concatenate([Jvj, np.zeros_like(Jbg), np.zeros_like(Jba)], axis=-1))
        else:
            Jinti.append(np.concatenate([Ji, Jvi, Jbg, Jba], axis=-1))
            Jintj.append(np.concatenate([Jj, Jvj, np.zeros_like(Jbg), np.zeros_like(Jba)], axis=-1))
        if GDir:
            Jgdir = inter.jacobian_GDir(T0_wb, T1_wb, V0w, V1w, Rwg)
            Jgs.append(np.concatenate([Jgdir, Js], axis=-1))

    info = _as_batched_tensor(info)
    Jinti = _as_batched_tensor(Jinti)
    Jintj = _as_batched_tensor(Jintj)
    eint = _as_batched_tensor(eint).unsqueeze(-1)

    # weighted squared error per edge
    chi2 = preint_scale * eint.transpose(2,3) @ info @ eint

    wJintiT = preint_scale * torch.matmul(Jinti.transpose(2,3), info)
    wJintjT = preint_scale * torch.matmul(Jintj.transpose(2,3), info)

    Hintii = torch.matmul(wJintiT, Jinti)
    Hintij = torch.matmul(wJintiT, Jintj)
    Hintji = torch.matmul(wJintjT, Jinti)
    Hintjj = torch.matmul(wJintjT, Jintj)

    vinti = torch.matmul(wJintiT, eint).squeeze(-1)
    vintj = torch.matmul(wJintjT, eint).squeeze(-1)

    if GDir:
        Jgs = _as_batched_tensor(Jgs)
        wJgsT = preint_scale * torch.matmul(Jgs.transpose(2,3), info)
        # the extra columns are shared by all edges, so their blocks are summed
        Hgs = torch.sum(torch.matmul(wJgsT, Jgs), dim=1)
        vgs = torch.sum(torch.matmul(wJgsT, eint), dim=1)
        Hi_gs = torch.matmul(wJintiT, Jgs)
        Hj_gs = torch.matmul(wJintjT, Jgs)
        return Hintii, Hintij, Hintji, Hintjj, vinti, vintj, Hgs, vgs, Hi_gs, Hj_gs, chi2
    return Hintii, Hintij, Hintji, Hintjj, vinti, vintj, chi2

def get_bias_factors(biass_w, preints, ii, jj, wo_pose=False, preint_scale=1e-5):
    """Build the normal-equation blocks of the bias-consistency factors.

    For every edge ``(i, j)`` the residual is ``biass_w[0, j] - biass_w[0, i]``
    and the information matrix is ``preints[(i, j)].info2``. The Jacobian is
    an identity placed on the last 6 columns of the per-frame state
    (9 columns when ``wo_pose`` is set, 15 otherwise), with opposite sign for
    frame ``j``.

    Returns ``(Hii, Hij, Hji, Hjj, vi, vj)``.
    """
    state_dim = 9 if wo_pose else 15
    bias_start = state_dim - 6

    jac_i, jac_j = [], []
    residuals, infos = [], []
    for i, j in zip(ii, jj):
        id0, id1 = i.item(), j.item()
        factor = preints[(id0, id1)]
        bias_i = biass_w[0, id0]
        bias_j = biass_w[0, id1]
        residuals.append(bias_j - bias_i)
        infos.append(factor.info2)

        # identity on the bias columns, zero elsewhere
        J_here = np.zeros((6, state_dim))
        J_here[:, bias_start:] = np.eye(6)
        jac_i.append(J_here)
        jac_j.append(-J_here)

    info = torch.tensor(np.stack(infos, axis=0)[None], dtype=torch.float, device='cuda')
    Jinti = torch.tensor(np.stack(jac_i, axis=0)[None], dtype=torch.float, device='cuda')
    Jintj = torch.tensor(np.stack(jac_j, axis=0)[None], dtype=torch.float, device='cuda')
    # residuals are already tensors, so they keep the device of biass_w
    eint = torch.stack(residuals, dim=0)[None].unsqueeze(-1)

    wJi_t = preint_scale * (Jinti.transpose(2, 3) @ info)
    wJj_t = preint_scale * (Jintj.transpose(2, 3) @ info)

    Hintii = wJi_t @ Jinti
    Hintij = wJi_t @ Jintj
    Hintji = wJj_t @ Jinti
    Hintjj = wJj_t @ Jintj

    vinti = (wJi_t @ eint).squeeze(-1)
    vintj = (wJj_t @ eint).squeeze(-1)
    return Hintii, Hintij, Hintji, Hintjj, vinti, vintj

def get_bias_prior_factors(biass_w, ii, preint_scale=1e-5):
    """Build prior blocks that pull each frame's bias towards frame 0's bias.

    For every index ``i`` in ``ii`` the residual is
    ``biass_w[0, 0] - biass_w[0, i]``, the Jacobian is an identity on the last
    6 of 15 state columns and the information matrix is
    ``diag(1e3, 1e3, 1e3, 1e2, 1e2, 1e2)``.

    NOTE(review): ``preint_scale`` is accepted but never applied here --
    confirm whether the fixed information weights are meant to be unscaled.

    Returns ``(Hprior, vprior)``.
    """
    residuals = [biass_w[0, 0] - biass_w[0, i.item()] for i in ii]
    count = len(residuals)

    # the Jacobian and information matrix are the same for every frame
    J_bias = np.zeros((6, 15))
    J_bias[:, 9:] = np.eye(6)
    W_bias = np.diag([1e3, 1e3, 1e3, 1e2, 1e2, 1e2])

    Jprior = torch.tensor(np.stack([J_bias] * count, axis=0)[None], dtype=torch.float, device='cuda')
    eprior = torch.stack(residuals, dim=0)[None].unsqueeze(-1)
    info = torch.tensor(np.stack([W_bias] * count, axis=0)[None], dtype=torch.float, device='cuda')

    wJ_t = Jprior.transpose(2, 3) @ info
    Hprior = wJ_t @ Jprior
    vprior = (wJ_t @ eprior).squeeze(-1)
    return Hprior, vprior

def InertialFullBA(t0, t1, poses_bw, velos_w, biass_w, disps, disps2, disps3, ii, preints, Rwg, H, E, C, v, w):
    """ Full Bundle Adjustment with inertial factors

    Adds IMU preintegration, bias-consistency and bias-prior factors for the
    consecutive frame pairs ``(t0-1, t0) ... (t1-2, t1-1)`` to an already
    assembled system ``(H, E, C, v, w)``, solves it and retracts the state of
    frames ``t0 .. t1-1``. Frame ``t0-1`` maps to local index -1 and is
    therefore left out of the system by the safe scatter helpers.

    Side effects: ``H`` and ``v`` are modified IN PLACE.

    Returns ``(poses_bw, velos_w, biass_w, disps, disps2, disps3)``.
    ``disps2`` / ``disps3`` are only updated when ``disps2`` is not None.
    """
    B, _, ht, wd = disps.shape
    D = poses_bw.manifold_dim + 3 + 6   # Pose(6) + Vel(3) + Bias gyr and acc(6)
    kx, _ = torch.unique(ii, return_inverse=True)
    M = kx.shape[0]     # number of depth maps to optimize
    P = t1 - t0         # number of poses to optimize

    def _pair_blocks(Hii, Hij, Hji, Hjj, vi, vj, a, b):
        # scatter the per-edge blocks of a pairwise factor into the P x P system
        dH = safe_scatter_add_mat(Hii, a, a, P, P) + \
             safe_scatter_add_mat(Hij, a, b, P, P) + \
             safe_scatter_add_mat(Hji, b, a, P, P) + \
             safe_scatter_add_mat(Hjj, b, b, P, P)
        dv = safe_scatter_add_vec(vi, a, P) + \
             safe_scatter_add_vec(vj, b, P)
        return dH, dv

    def _retract_disps(maps, dz_part):
        # apply the depth update, zero out values above 10, clamp at 0
        maps = disp_retr(maps, dz_part.view(B, -1, ht, wd), kx)
        maps = torch.where(maps > 10, torch.zeros_like(maps), maps)
        return maps.clamp(min=0.0)

    ### 2b: add preint imu factors ###
    src = torch.arange(t0-1, t1-1, device='cuda') - t0   # local index of frame i
    dst = src + 1                                        # local index of frame j
    *imu_terms, _ = get_preint_factors(poses_bw, velos_w, biass_w, preints, Rwg, src+t0, dst+t0)
    dH, dv = _pair_blocks(*imu_terms, src, dst)
    H += dH
    v += dv

    ### 2c: add bias consistency factors ###
    dH, dv = _pair_blocks(*get_bias_factors(biass_w, preints, src+t0, dst+t0), src, dst)
    H += dH
    v += dv

    # bias prior towards the bias of frame 0
    Hprior, vprior = get_bias_prior_factors(biass_w, src+t0)
    H += safe_scatter_add_mat(Hprior, src, src, P, P)
    v += safe_scatter_add_vec(vprior, src, P)

    ### 3: solve the system ###
    H = H.view(B, P, P, D, D)
    # NOTE(review): two values are unpacked here while BA unpacks three from
    # the same helper -- confirm against schur_solve's signature.
    dxv, dz = schur_solve(H, E, C, v, w)

    ### 4: apply retraction ###
    frames = torch.arange(P) + t0
    poses_bw = pose_retr(poses_bw, dxv[..., :6], frames)
    velos_w = velo_retr(velos_w, dxv[..., 6:9], frames)
    biass_w = bias_retr(biass_w, dxv[..., 9:], frames)
    disps = _retract_disps(disps, dz[:, :M])

    if disps2 is not None:
        disps2 = _retract_disps(disps2, dz[:, M:2*M])
        disps3 = _retract_disps(disps3, dz[:, 2*M:3*M])
    return poses_bw, velos_w, biass_w, disps, disps2, disps3

def InitializeFullInertialBA(t0, t1, poses_bw, velos_w, biass_w, disps, disps2, disps3, ii, preints, Rwg, scale, H, E, C, v, w, fix_front, imu_init_fix_scale, bias_scale=1.0):
    """ Full Bundle Adjustment with both reprojection and inertial factors

    Adds IMU preintegration (including the extra gravity-direction / scale
    columns) and bias-consistency factors for the consecutive frame pairs
    ``(t0, t0+1) ... (t1-2, t1-1)`` to an already assembled system
    ``(H, E, C, v, w)``, solves it with ``schur_solve_imu`` and retracts the
    state of frames ``t0 .. t1-1``. The bias update is multiplied by
    ``bias_scale``. When ``imu_init_fix_scale`` is false the solver also
    returns ``ds`` and ``scale`` is multiplied by ``exp(ds)``.

    Side effects: ``H`` and ``v`` are modified IN PLACE; diagnostics are
    printed.

    Returns ``(poses_bw, velos_w, biass_w, disps, disps2, disps3, scale)``.
    """
    B, _, ht, wd = disps.shape
    D = poses_bw.manifold_dim + 3 + 6   # Pose(6) + Vel(3) + Bias gyr and acc(6)

    kx, _ = torch.unique(ii, return_inverse=True)
    M = kx.shape[0]     # number of depth maps to optimize
    P = t1 - t0         # number of poses to optimize

    def _pair_blocks(Hii, Hij, Hji, Hjj, vi, vj, a, b):
        # scatter the per-edge blocks of a pairwise factor into the P x P system
        dH = safe_scatter_add_mat(Hii, a, a, P, P) + \
             safe_scatter_add_mat(Hij, a, b, P, P) + \
             safe_scatter_add_mat(Hji, b, a, P, P) + \
             safe_scatter_add_mat(Hjj, b, b, P, P)
        dv = safe_scatter_add_vec(vi, a, P) + \
             safe_scatter_add_vec(vj, b, P)
        return dH, dv

    def _retract_disps(maps, dz_part):
        # apply the depth update, zero out values above 10, clamp at 0
        maps = disp_retr(maps, dz_part.view(B, -1, ht, wd), kx)
        maps = torch.where(maps > 10, torch.zeros_like(maps), maps)
        return maps.clamp(min=0.0)

    ### 2b: add preint imu factors ###
    src = torch.arange(t0, t1-1, device='cuda') - t0   # local index of frame i
    dst = src + 1                                      # local index of frame j
    (*imu_terms, Hgdir, vgdir, Hi_gdir, Hj_gdir, chi2) = get_preint_factors(
        poses_bw, velos_w, biass_w, preints, Rwg, src+t0, dst+t0, GDir=True, scale=scale)
    print("- - Chi2 error preint: {:.5f}".format(torch.sum(chi2).item()))
    dH, dv = _pair_blocks(*imu_terms, src, dst)
    H += dH
    v += dv

    ### 2c: add bias consistency factors ###
    dH, dv = _pair_blocks(*get_bias_factors(biass_w, preints, src+t0, dst+t0), src, dst)
    H += dH
    v += dv

    ### 3: solve the system ###
    H = H.view(B, P, P, D, D)
    if imu_init_fix_scale:
        # the gravity/scale Hessian is withheld from the solver in this mode
        dxv, dz = schur_solve_imu(H, E, C, v, w, None, vgdir, Hi_gdir, Hj_gdir, fix_front=fix_front)
    else:
        dxv, dz, ds = schur_solve_imu(H, E, C, v, w, Hgdir, vgdir, Hi_gdir, Hj_gdir, fix_front=fix_front)

    ### 4: apply retraction ###
    frames = torch.arange(P) + t0
    poses_bw = pose_retr(poses_bw, dxv[..., :6], frames)
    velos_w = velo_retr(velos_w, dxv[..., 6:9], frames)
    biass_w = bias_retr(biass_w, dxv[..., 9:]*bias_scale, frames)
    disps = _retract_disps(disps, dz[:, :M])

    if disps2 is not None:
        disps2 = _retract_disps(disps2, dz[:, M:2*M])
        disps3 = _retract_disps(disps3, dz[:, 2*M:3*M])

    if not imu_init_fix_scale:
        print(ds)
        scale *= math.exp(ds)
    return poses_bw, velos_w, biass_w, disps, disps2, disps3, scale

def InitializeVeloBiasGdir(t0, t1, poses_bw, velos_w, biass_w, preints, Rwg, init_g, Tcb, mean_g, scale, imu_init_fix_scale, fix_front=3, bias_scale=1.0):
    """ Initialize velocities, biases and gravity direction via inertial preintegrate/prior factor

    Poses are held fixed: the system contains only velocity and bias states
    (9 per frame) of frames ``t0 .. t1-1`` plus the shared gravity-direction /
    scale unknowns. When ``mean_g`` is given a ``ConstantGravityFactor`` prior
    is passed to the solver as well. The bias update is multiplied by
    ``bias_scale``; when ``imu_init_fix_scale`` is false ``scale`` is
    multiplied by ``exp(dgs[0, 3])``.

    Side effects: diagnostics are printed.

    Returns ``(velos_w, biass_w, Rwg, scale)``.
    """
    B, _ = poses_bw.shape
    D = 3 + 6   # Vel(3) + Bias gyr and acc(6)
    P = t1 - t0

    def _pair_blocks(Hii, Hij, Hji, Hjj, vi, vj, a, b):
        # scatter the per-edge blocks of a pairwise factor into the P x P system
        dH = safe_scatter_add_mat(Hii, a, a, P, P) + \
             safe_scatter_add_mat(Hij, a, b, P, P) + \
             safe_scatter_add_mat(Hji, b, a, P, P) + \
             safe_scatter_add_mat(Hjj, b, b, P, P)
        dv = safe_scatter_add_vec(vi, a, P) + \
             safe_scatter_add_vec(vj, b, P)
        return dH, dv

    ### 2b: add preint imu factors ###
    src = torch.arange(0, t1-1 - t0, device='cuda')   # local index of frame i
    dst = src + 1                                     # local index of frame j
    (*imu_terms, Hgdir, vgdir, Hi_gdir, Hj_gdir, chi2) = get_preint_factors(
        poses_bw, velos_w, biass_w, preints, Rwg, src+t0, dst+t0, GDir=True, wo_pose=True, scale=scale)
    H, v = _pair_blocks(*imu_terms, src, dst)

    ### 2c: add bias consistency factors ###
    dH, dv = _pair_blocks(*get_bias_factors(biass_w, preints, src+t0, dst+t0, wo_pose=True), src, dst)
    H += dH
    v += dv

    # optional gravity / accelerometer-bias prior
    if mean_g is None:
        H_gba, v_gba = None, None
    else:
        H_gba, v_gba = ConstantGravityFactor(Rwg, init_g, mean_g, biass_w[0,0,3:].clone(), Tcb)

    ### 3: solve the system ###
    H = H.view(B, P, P, D, D)
    if imu_init_fix_scale:
        # drop the last row/column of the shared block
        # NOTE(review): vgdir / Hi_gdir / Hj_gdir are passed untrimmed --
        # presumably handled inside block_solve_imu; confirm.
        Hgdir = Hgdir[:, :-1, :-1]
    dx, dgs = block_solve_imu(H, v, Hgdir, vgdir, Hi_gdir, Hj_gdir, fix_front=fix_front, H_gba=H_gba, v_gba=v_gba)
    print('dx', dx.shape, 'dgs', dgs.shape)

    ### 4: apply retraction ###
    frames = torch.arange(P) + t0
    velos_w = velo_retr(velos_w, dx[..., :3], frames)
    biass_w = bias_retr(biass_w, dx[..., 3:]*bias_scale, frames)
    rot_update = sp.SO3.exp(dgs[0,:3].cpu().numpy()).matrix()
    Rwg = Rwg @ rot_update
    if not imu_init_fix_scale:
        print(dgs)
        scale *= math.exp(dgs[0,3])

    return velos_w, biass_w, Rwg, scale

def InitializeGravityDirectionDynamic(t0, t1, poses_bw, velos_w, biass_w, preints, Rwg):
    """ Initialize gravity direction via inertial preintegrate factor

    Evaluates the preintegration factors of the consecutive frame pairs
    ``(t0, t0+1) ... (t1-2, t1-1)``, keeps only the shared gravity-direction
    block, solves it with ``solve_dR`` and right-multiplies ``Rwg`` by the
    exponential of the first three entries of the solution.

    Side effects: prints the summed chi2 of the preintegration factors.

    Returns the updated ``Rwg``.
    """
    src = torch.arange(0, t1-1 - t0, device='cuda')
    dst = src + 1
    (_, _, _, _, _, _, Hgdir, vgdir, _, _, chi2) = get_preint_factors(
        poses_bw, velos_w, biass_w, preints, Rwg, src+t0, dst+t0, GDir=True, wo_pose=True)
    print("- - preint chi2 error: {:.5f}".format(torch.sum(chi2).item()))

    dR = solve_dR(Hgdir, vgdir)
    rot_update = sp.SO3.exp(dR[0,:3].cpu().numpy()).matrix()
    return Rwg @ rot_update

def ConstantGravityFactor(Rwg, init_g, mean_g, Bias_c, Tcb=None):
    """Build the normal equations of a gravity prior plus a bias prior.

    The residual is ``-(mean_g - Bias_c) - Rwg @ init_g`` weighted by 1e6.
    The 6 Jacobian columns are: two filled from ``Rwg @ Gm`` (``Gm`` holds
    +/-9.81), one left at zero, and three for the bias (minus identity, or
    minus the rotation block of ``Tcb`` when it is given, in which case
    ``Bias_c`` is rotated by that block first). A second term weighted by 1e3
    with residual ``Bias_c`` and Jacobian minus identity is added on the bias
    block.

    Returns ``(H, v)`` on ``'cuda'`` with shapes ``(1, 6, 6)`` and ``(1, 6, 1)``.
    """
    eye3 = torch.eye(3, dtype=torch.float, device='cuda')
    Jprior = torch.zeros((1,3,6), device='cuda')

    if Tcb is None:
        Jprior[:,:,3:] = -eye3
    else:
        # express the bias through the rotation block of Tcb
        Bias_c = Tcb.matrix()[0,0,:3,:3] @ Bias_c
        Jprior[:,:,3:] = (-Tcb.matrix()[0,0,:3,:3]) @ eye3

    # gravity residual, evaluated in numpy then moved to the GPU
    grav_err = -(mean_g - Bias_c.cpu().numpy()) - Rwg @ init_g
    grav_err = torch.tensor(grav_err, dtype=torch.float, device='cuda')[None].unsqueeze(-1)

    # columns for the two gravity-direction unknowns
    Gm = torch.zeros((3,2), device='cuda')
    Gm[0,1] = -9.81
    Gm[1,0] = 9.81
    Jprior[:,:,:2] = torch.tensor(Rwg, dtype=torch.float, device='cuda') @ Gm

    wJ_t = 1e6 * Jprior.transpose(1,2)
    H = wJ_t @ Jprior
    v = wJ_t @ grav_err

    # additional prior on the bias block
    bias_err = Bias_c[None].unsqueeze(-1)
    Jbias = -eye3.unsqueeze(0)
    wJb_t = 1e3 * Jbias.transpose(1,2)
    H[:,3:,3:] += wJb_t @ Jbias
    v[:,3:] += wJb_t @ bias_err
    return H, v

def InitializeGravityDirectionStandstill(Rwg, init_g, mean_g, Bias_c):
    """ Initialize gravity direction via averaged gravity vector at standstill

    Builds the ``ConstantGravityFactor`` system, solves it with ``solve_dR``,
    right-multiplies ``Rwg`` by the exponential of the first three entries of
    the solution and adds the remaining entries to ``Bias_c``.

    Side effects: ``Bias_c`` is updated IN PLACE.

    Returns ``(Rwg, Bias_c)``.
    """
    H_prior, v_prior = ConstantGravityFactor(Rwg, init_g, mean_g, Bias_c)
    delta = solve_dR(H_prior, v_prior)

    rot_update = sp.SO3.exp(delta[0,:3].cpu().numpy()).matrix()
    Rwg = Rwg @ rot_update
    Bias_c += delta[0,3:,0]
    return Rwg, Bias_c
